{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# Copyright (c) 2020 NVIDIA. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# =============================================================================\n",
    "\n",
    "from functools import partial\n",
    "from os.path import expanduser, join, abspath, dirname, exists\n",
    "import tarfile\n",
    "\n",
    "from ruamel.yaml import YAML\n",
    "\n",
    "import nemo\n",
    "import nemo.collections.asr as nemo_asr\n",
    "from nemo.collections.asr.helpers import monitor_asr_train_progress\n",
    "from nemo.core import NeuralGraph, OperationMode, DeviceType, SimpleLossLoggerCallback\n",
    "from nemo.utils import logging\n",
    "from nemo.utils.app_state import AppState\n",
    "\n",
    "# Create Neural(Module)Factory, use CPU.\n",
    "nf = nemo.core.NeuralModuleFactory(placement=DeviceType.CPU)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tutorial II: The advanced functionality\n",
    "\n",
    "In this second part of the Neural Graphs (NGs) tutorial we will focus on a more complex example: training of an End-to-End Convolutional Neural Acoustic Model called JASPER. We will build a \"model graph\" and show how we can nest it into another graph, how we can freeze/unfreeze modules, use graph configuration and save/load graph checkpoints.\n",
    "\n",
    "#### This part covers the following:\n",
    " * how to nest one graph into another\n",
    " * how to serialize and deserialize a graph\n",
    " * how to export and import serialized graph configuration to/from YAML files\n",
    " * how to save and load graph checkpoints (containing weights of the Trainable NMs)\n",
    " * how to freeze/unfreeze modules in a graph\n",
    " \n",
    "Additionally, we will show how to use `AppState` to list all the modules and graphs we have created in the scope of our application.\n",
    "In order to learn more about graph nesting and input/output binding please refer to the first part of the tutorial.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare the samples for training JASPER - we will use the data available in NeMo tests.\n",
    "data_folder = abspath(\"../../tests/data/\")\n",
    "# Folder that is expected to hold the extracted ASR samples.\n",
    "asr_folder = join(data_folder, \"asr\")\n",
    "logging.info(\"Looking up for test ASR data\")\n",
    "if not exists(asr_folder):\n",
    "    logging.info(\"Extracting ASR data to: {0}\".format(asr_folder))\n",
    "    # The context manager closes the archive even when the extraction fails.\n",
    "    with tarfile.open(join(data_folder, \"asr.tar.gz\"), \"r:gz\") as tar:\n",
    "        tar.extractall(path=data_folder)\n",
    "else:\n",
    "    logging.info(\"ASR data found in: {0}\".format(asr_folder))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Locations of the model configuration, the manifest and the tarred samples.\n",
    "model_config_file = abspath(\"../asr/configs/jasper_an4.yaml\")\n",
    "manifest_path = join(data_folder, \"asr/tarred_an4/tarred_audio_manifest.json\")\n",
    "tarpath = join(data_folder, \"asr/tarred_an4/audio_1.tar\")\n",
    "\n",
    "# Parse the model configuration with the \"safe\" YAML loader.\n",
    "yaml = YAML(typ=\"safe\")\n",
    "with open(expanduser(model_config_file)) as config_stream:\n",
    "    config = yaml.load(config_stream)\n",
    "\n",
    "# Extract the labels (vocabulary) and their number.\n",
    "vocab = config[\"labels\"]\n",
    "vocab_len = len(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data layer that loads the tarred samples described by the manifest.\n",
    "data_layer = nemo_asr.TarredAudioToTextDataLayer(\n",
    "    audio_tar_filepaths=tarpath,\n",
    "    manifest_filepath=manifest_path,\n",
    "    labels=vocab,\n",
    "    batch_size=16,\n",
    ")\n",
    "logging.info(\"Loaded {} samples that we will use for training\".format(len(data_layer)))\n",
    "\n",
    "# The remaining modules are created from their sections of the model config,\n",
    "# using the Neural Module deserialization feature.\n",
    "data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor.deserialize(\n",
    "    config[\"AudioToMelSpectrogramPreprocessor\"]\n",
    ")\n",
    "\n",
    "jasper_encoder = nemo_asr.JasperEncoder.deserialize(config[\"JasperEncoder\"])\n",
    "# The number of classes of the decoder is overwritten with the vocabulary size.\n",
    "jasper_decoder = nemo_asr.JasperDecoderForCTC.deserialize(\n",
    "    config[\"JasperDecoderForCTC\"],\n",
    "    overwrite_params={\"num_classes\": vocab_len},\n",
    ")\n",
    "\n",
    "# The loss and the greedy decoder are instantiated directly.\n",
    "ctc_loss = nemo_asr.CTCLossNM(num_classes=vocab_len)\n",
    "greedy_decoder = nemo_asr.GreedyCTCDecoder()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the Jasper \"model\" graph.\n",
    "with NeuralGraph(operation_mode=OperationMode.both, name=\"jasper_model\") as jasper_model:\n",
    "    # Copy a single input port definition, exposing it under the \"user\" port name \"input\".\n",
    "    jasper_model.inputs[\"input\"] = data_preprocessor.input_ports[\"input_signal\"]\n",
    "    # Bind \"input_signal\" to that port; \"length\" is bound to the graph using the default port name.\n",
    "    i_processed_signal, i_processed_signal_len = data_preprocessor(\n",
    "        input_signal=jasper_model.inputs[\"input\"],\n",
    "        length=jasper_model,\n",
    "    )\n",
    "    i_encoded, i_encoded_len = jasper_encoder(\n",
    "        audio_signal=i_processed_signal,\n",
    "        length=i_processed_signal_len,\n",
    "    )\n",
    "    i_log_probs = jasper_decoder(encoder_output=i_encoded)\n",
    "    # Expose the selected outputs under \"user\" port names.\n",
    "    jasper_model.outputs[\"log_probs\"] = i_log_probs\n",
    "    jasper_model.outputs[\"encoded_len\"] = i_encoded_len\n",
    "\n",
    "# Print the summary.\n",
    "logging.info(jasper_model.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Serialize the whole graph.\n",
    "serialized_jasper = jasper_model.serialize()\n",
    "logging.info(\"Serialized JASPER model:\\n {}\".format(serialized_jasper))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You can also serialize/deserialize a single NeuralModule, e.g. a decoder.\n",
    "logging.info(\"Serialized JASPER Decoder:\\n {}\".format(jasper_decoder.serialize()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We can also export the serialized configuration to a file.\n",
    "jasper_model.export_to_config(\"my_jasper.yml\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the lists of graph and modules.\n",
    "logging.info(AppState().graphs.summary())\n",
    "logging.info(AppState().modules.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Deserialize graph - create a copy of the JASPER \"model\".\n",
    "# Please note that the modules exist, so we must enable the graph to \"reuse\" them.\n",
    "# (Commenting out reuse_existing_modules will raise a KeyError.)\n",
    "jasper_copy = NeuralGraph.deserialize(serialized_jasper, reuse_existing_modules=True)\n",
    "serialized_jasper_copy = jasper_copy.serialize()\n",
    "assert serialized_jasper == serialized_jasper_copy # THE SAME! Please note name of the graph is not exported."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternatively, import a copy of the JASPER \"model\" from the exported config file.\n",
    "jasper_copy = NeuralGraph.import_from_config(\n",
    "    \"my_jasper.yml\",\n",
    "    reuse_existing_modules=True,\n",
    "    name=\"jasper_copy\",\n",
    ")\n",
    "\n",
    "# Print the summary of the imported graph.\n",
    "logging.info(jasper_copy.summary())\n",
    "\n",
    "# Display the lists of graphs and modules.\n",
    "logging.info(AppState().graphs.summary())\n",
    "logging.info(AppState().modules.summary())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that there are two graphs in the \"Graph Registry\", yet the list of modules hasn't changed. This means that both graphs are built on the same list of modules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the \"training\" graph.\n",
    "with NeuralGraph(operation_mode=OperationMode.training) as training_graph:\n",
    "    # The data layer starts the \"implicit\" training graph.\n",
    "    o_audio_signal, o_audio_signal_len, o_transcript, o_transcript_len = data_layer()\n",
    "    # The nested Jasper graph is called like any other neural module.\n",
    "    o_log_probs, o_encoded_len = jasper_copy(\n",
    "        input=o_audio_signal,\n",
    "        length=o_audio_signal_len,\n",
    "    )\n",
    "    o_predictions = greedy_decoder(log_probs=o_log_probs)\n",
    "    o_loss = ctc_loss(\n",
    "        log_probs=o_log_probs,\n",
    "        targets=o_transcript,\n",
    "        input_length=o_encoded_len,\n",
    "        target_length=o_transcript_len,\n",
    "    )\n",
    "    # Expose the loss as the graph output.\n",
    "    training_graph.outputs[\"o_loss\"] = o_loss\n",
    "\n",
    "# Print the summary.\n",
    "logging.info(training_graph.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Callback logging the value of the \"o_loss\" graph output at every step.\n",
    "loss_callback = SimpleLossLoggerCallback(\n",
    "    tensors=[training_graph.output_tensors[\"o_loss\"]],\n",
    "    print_func=lambda values: logging.info(f'Train Loss: {values[0].item()}'),\n",
    "    step_freq=1,\n",
    ")\n",
    "# Run a short training session on the graph.\n",
    "nf.train(\n",
    "    training_graph=training_graph,\n",
    "    optimizer=\"novograd\",\n",
    "    callbacks=[loss_callback],\n",
    "    optimization_params=dict(max_steps=5, lr=0.01),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please note that the loss is going down. Still, we use only 65 samples, so we cannot really expect the model to be useful ;)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Finally, I can save the graph checkpoint!\n",
    "# Note that optionally you can indicate the names of the modules to be saved.\n",
    "jasper_copy.save_to(\"my_jasper.chkpt\")#, module_names=[\"jasperencoder0\"])\n",
    "# Please note only \"trainable\" modules will be saved."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We can also save the whole training graph - which in this case will result in the same checkpoint...\n",
    "training_graph.export_to_config(\"my_whole_graph.yml\")\n",
    "training_graph.save_to(\"my_whole_graph.chkpt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Finally, I can load everything and continue training.\n",
    "new_training_graph = NeuralGraph.import_from_config(\"my_whole_graph.yml\", reuse_existing_modules=True)\n",
    "\n",
    "# Let's restore only the encoder\n",
    "new_training_graph.restore_from(\"my_whole_graph.chkpt\", module_names=[\"jasperencoder0\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# So let us freeze the whole graph that we are going to train next...\n",
    "# (we can also freeze a subset, using \"module_names=[...]\").\n",
    "# Note: new_training_graph was imported with reuse_existing_modules=True,\n",
    "# so it operates on the already existing modules.\n",
    "new_training_graph.freeze()\n",
    "# ... and finetune only the decoder.\n",
    "new_training_graph.unfreeze(module_names=[\"jasperdecoderforctc0\"])\n",
    "\n",
    "# Ok, let us see what the graph looks like now.\n",
    "logging.info(new_training_graph.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# New callback logging the \"o_loss\" output of the reloaded graph at every step.\n",
    "loss_callback = SimpleLossLoggerCallback(\n",
    "    tensors=[new_training_graph.output_tensors[\"o_loss\"]],\n",
    "    print_func=lambda values: logging.info(f'Train Loss: {values[0].item()}'),\n",
    "    step_freq=1,\n",
    ")\n",
    "\n",
    "# Reset the trainer and continue training on the reloaded graph.\n",
    "nf.reset_trainer()\n",
    "nf.train(\n",
    "    training_graph=new_training_graph,\n",
    "    optimizer=\"novograd\",\n",
    "    callbacks=[loss_callback],\n",
    "    optimization_params=dict(max_steps=5, lr=0.01),\n",
    ")\n",
    "# Please note that this will throw an error if you freeze all the trainable modules!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nemo-env",
   "language": "python",
   "name": "nemo-env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}